{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "checkpoint.ipynb",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Pnn4rDWGqDZL"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "colab_type": "code",
        "id": "l534d35Gp68G",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3TI3Q3XBesaS"
      },
      "source": [
        "# Entrenar checkpoints"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yw_a0iGucY8z"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/checkpoint\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />Ver en TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/es/guide/checkpoint.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Ejecutar en Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/es/guide/checkpoint.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />Ver fuente en GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/es/guide/checkpoint.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Descargar notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4HYmIbxDQOCF",
        "colab_type": "text"
      },
      "source": [
        "Note: Nuestra comunidad de Tensorflow ha traducido estos documentos. Como las traducciones de la comunidad\n",
        "son basados en el \"mejor esfuerzo\", no hay ninguna garantia que esta sea un reflejo preciso y actual \n",
        "de la [Documentacion Oficial en Ingles](https://www.tensorflow.org/?hl=en).\n",
        "Si tienen sugerencias sobre como mejorar esta traduccion, por favor envian un \"Pull request\"\n",
        "al siguiente repositorio [tensorflow/docs](https://github.com/tensorflow/docs).\n",
        "Para ofrecerse como voluntario o hacer revision de las traducciones de la Comunidad\n",
        "por favor contacten al siguiente grupo [docs@tensorflow.org list](https://groups.google.com/a/tensorflow.org/forum/#!forum/docs)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LeDp7dovcbus"
      },
      "source": [
        "La frase \"Saving a TensorFlow model\" significa tipicamente una de las dos cosas: (1) Checkpoints, O (2) SavedModel.\n",
        "\n",
        "Los Checkpoints capturan el valor exacto de todos los parametros (objetos `tf.Variable`) usados por un modelo. Los Checkpoints no almacenan ninguna descripcion del computo utilizado por el modelo. Por lo mismo, los checkpoints solo son utiles cuando el codigo que usara los parametros almacenados esta disponible.\n",
        "\n",
        "Por otro lado, el formato SavedModel incluye una descripcion serializada del computo definido por el modelo ademas de los valores de los parametros (checkpoint). Con este formato, los modelos son independientes al codigo que creo el mismo. Por ende son idoneos para el despliegue de los modelos a traves de TensorFlos Serving, TensorFlow Lite, TensorFlow.js, o programas en otros lenguajes de programacion (las APIs de TensorFlow para C, C++, Java, Go, Rust, C# etc.)\n",
        "\n",
        "Esta guia cubre las APIs para leer y escribir checkpoints."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5vsq3-pffo1I"
      },
      "source": [
        "## Guardando de las APIs de entrenamiento de `tf.keras`\n",
        "\n",
        "Pueden referirse a la [ guia para guardar y restaurar de `tf.keras`](https://www.tensorflow.org/tutorials/keras/save_and_restore_models), NOTA: al momento esta en ingles.\n",
        "\n",
        "`tf.keras.Model.save_weights` tambien permite la opcion de guardar en el formato TensorFlow checkpoint. Esta guia explica el formato a mayor detalle e introduce las APIs para administrar los checkpoints en bucles de entrenamiento personalizados."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "XseWX5jDg4lQ"
      },
      "source": [
        "## Definir checkpoints manualmente"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1jpZPz76ZP3K"
      },
      "source": [
        "El estado persistente de un modelo de TensorFlow es almacenado en objectos `tf.Variable`. Estos objetos pueden ser construidos directamente, pero comunmente con creados mediante APIs de alto nivel tales como `tf.keras.layers`.\n",
        "\n",
        "La manera mas sencilla de admistrar las variables es asociandolas a objetos de Python, y despues referenciando dichos objetos. Las subclases de `tf.train.Checkpoint`, `tf.keras.layers.Layer`, y `tf.keras.Model` rastrean automaticamente las variables asociadas a sus atributos. El ejemplo a contunuacion construye un modelo linear simple, y posteriormente escribe checkpoints que contienen valores para todas las variables del modelo."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "VEvpMYAKsC4z",
        "colab": {}
      },
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "try:\n",
        "  # %tensorflow_version solo existe en Colab.\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass\n",
        "import tensorflow as tf"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "BR5dChK7rXnj",
        "colab": {}
      },
      "source": [
        "class Net(tf.keras.Model):\n",
        "  \"\"\"Un Modelo Linear simple.\"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super(Net, self).__init__()\n",
        "    self.l1 = tf.keras.layers.Dense(5)\n",
        "\n",
        "  def call(self, x):\n",
        "    return self.l1(x)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fNjf9KaLdIRP"
      },
      "source": [
        "Este ejemplo necesita datos y un paso de optimizacion para poder ser ejecutable aunque esta guia no se trate de esos temas. El modelo entrenara por slices de un dataset en memoria."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "tSNyP4IJ9nkU",
        "colab": {}
      },
      "source": [
        "def toy_dataset():\n",
        "  inputs = tf.range(10.)[:, None]\n",
        "  labels = inputs * 5. + tf.range(5.)[None, :]\n",
        "  return tf.data.Dataset.from_tensor_slices(\n",
        "    dict(x=inputs, y=labels)).repeat(10).batch(2)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ICm1cufh_JH8",
        "colab": {}
      },
      "source": [
        "def train_step(net, example, optimizer):\n",
        "  \"\"\"Entrena `net` en `example` usando `optimizer`.\"\"\"\n",
        "  with tf.GradientTape() as tape:\n",
        "    output = net(example['x'])\n",
        "    loss = tf.reduce_mean(tf.abs(output - example['y']))\n",
        "  variables = net.trainable_variables\n",
        "  gradients = tape.gradient(loss, variables)\n",
        "  optimizer.apply_gradients(zip(gradients, variables))\n",
        "  return loss"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NP9IySmCeCkn"
      },
      "source": [
        "El siguiente buckle crea una instancia del modelo y de un optimizer, despues los recolecta en un objeto `tf.train.Checkpoint`. Llama el paso de entrenamiento en un ciclo para cada batch de datos, y escribe periodicamente checkpoints en disco."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "BbCS5A6K1VSH",
        "colab": {}
      },
      "source": [
        "opt = tf.keras.optimizers.Adam(0.1)\n",
        "net = Net()\n",
        "ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)\n",
        "manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)\n",
        "ckpt.restore(manager.latest_checkpoint)\n",
        "if manager.latest_checkpoint:\n",
        "  print(\"Restaurado de {}\".format(manager.latest_checkpoint))\n",
        "else:\n",
        "  print(\"Inicializando desde cero.\")\n",
        "\n",
        "for example in toy_dataset():\n",
        "  loss = train_step(net, example, opt)\n",
        "  ckpt.step.assign_add(1)\n",
        "  if int(ckpt.step) % 10 == 0:\n",
        "    save_path = manager.save()\n",
        "    print(\"Checkpoint almacenado para el paso {}: {}\".format(int(ckpt.step), save_path))\n",
        "    print(\"loss {:1.2f}\".format(loss.numpy()))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lw1QeyRBgsLE"
      },
      "source": [
        "El snippet anterior inicializara aleatoriamente las variables del modelo en su primera corrida. Posterior a esta, reanudara el entrenamiendo en donde se quedo:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "UjilkTOV2PBK",
        "colab": {}
      },
      "source": [
        "opt = tf.keras.optimizers.Adam(0.1)\n",
        "net = Net()\n",
        "ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)\n",
        "manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)\n",
        "ckpt.restore(manager.latest_checkpoint)\n",
        "if manager.latest_checkpoint:\n",
        "  print(\"Restaurado de {}\".format(manager.latest_checkpoint))\n",
        "else:\n",
        "  print(\"Inicializando desde cero\")\n",
        "\n",
        "for example in toy_dataset():\n",
        "  loss = train_step(net, example, opt)\n",
        "  ckpt.step.assign_add(1)\n",
        "  if int(ckpt.step) % 10 == 0:\n",
        "    save_path = manager.save()\n",
        "    print(\"Checkpoint almacenado para el paso {}: {}\".format(int(ckpt.step), save_path))\n",
        "    print(\"loss {:1.2f}\".format(loss.numpy()))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dxJT9vV-2PnZ"
      },
      "source": [
        "El objeto `tf.train.CheckpointManager` elimina checkpoints viejos. Arriba ha sido configurado para conservar los tres checkpoints mas recientes."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "3zmM0a-F5XqC",
        "colab": {}
      },
      "source": [
        "print(manager.checkpoints)  # Lista los tres checkpoints restantes"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qwlYDyjemY4P"
      },
      "source": [
        "Estos paths, e.g. `'./tf_ckpts/ckpt-10'`, no son archivos en disco. Son prefijos para un archivo tipo `index` y uno o mas archivos de datos que contienen los valores de las variables. Estos prefijos estan agrupados en un unico archivo de `checkpoint` (`'./tf_ckpts/checkpoint'`) donde el `CheckpointManager` guarda su estado."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "t1feej9JntV_",
        "colab": {}
      },
      "source": [
        "!ls ./tf_ckpts"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DR2wQc9x6b3X"
      },
      "source": [
        "<a id=\"loading_mechanics\"/>\n",
        "## Mecanica de carga\n",
        "\n",
        "TensorFlow hace coincider las variables con los valores almacenados como *checkpoint* atravesando un grafo dirigido con aristas nombradas, comenzando con el objeto que este siendo cargado. Los nombres de las aristas vienen de los nombres de los atributos en los objetos, por ejemplo el `\"l1\"` en `self.l1 = tf.keras.layers.Dense(5)`. `tf.train.Checkpoint` usa sus argumentos de palabras clave como nombre, tal como el `\"step\"` en `tf.train.Checkpoint(step=...)`.\n",
        "\n",
        "El grafo de dependencias del ejemplo anterior se ve asi:\n",
        "\n",
        "![Visualization of the dependency graph for the example training loop](http://tensorflow.org/images/guide/whole_checkpoint.svg)\n",
        "\n",
        "Con el *optimizer* en rojo, las variables regulares en azul, y variables slot del *optimizer* en naranja. Los otros nodos, por ejemplo el que representa el `tf.train.Checkpoint`, son negros.\n",
        "\n",
        "Las variables *Slot* son parte del estado del *optimizer*, pero son creadas para una variable especifica. Por ejemplo las aristas `'m'` de arriba corresponden a un momentum, los cuales son rastreados por el *Adam's* *optimizer* para cada variable. Las variables *Slot* solo son almacenadas en un *checkpoint* si la variable y el optimizer serian ambas almacenadas, por eso los aristas punteados."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VpY5IuanUEQ0"
      },
      "source": [
        "La llamada `restore()` de un objeto `tf.train.Checkpoint` hace cola las restauraciones requeridas, restaurando los valores de las variables tan pronto como se encuentre un *path* correspondiente en el objeto `Checkpoint`. Por ejemplo podemos cargar solo el kernel del model que definimos anteriormente recosntruyendo un path a el mediante la red y la capa (layer)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wmX2AuyH7TVt",
        "colab": {}
      },
      "source": [
        "to_restore = tf.Variable(tf.zeros([5]))\n",
        "print(to_restore.numpy())  # Puros ceros\n",
        "fake_layer = tf.train.Checkpoint(bias=to_restore)\n",
        "fake_net = tf.train.Checkpoint(l1=fake_layer)\n",
        "new_root = tf.train.Checkpoint(net=fake_net)\n",
        "status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))\n",
        "print(to_restore.numpy())  # Ahora obtenemos el valor restaurado"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GqEW-_pJDAnE"
      },
      "source": [
        "El grafo de dependencias para estos objetos nuevos es un sub-grafo del checkpoint que escribimos anteriormente. Solo incluye el *bias* y un *save counter* que el `tf.train.Checkpoint` usa para enumerar los *checkpoints*.\n",
        "\n",
        "![Visualization of a subgraph for the bias variable](http://tensorflow.org/images/guide/partial_checkpoint.svg)\n",
        "\n",
        "`restore()` regresa el estado del objeto, que tiene afirmaciones (assertions) opcionales. Todos los objetos que hemos creado en nuestro nuevo `Checkpoint` han sido restaurados, asi que `status.assert_existing_objects_matched()` pasa exitosamente."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "P9TQXl81Dq5r",
        "colab": {}
      },
      "source": [
        "status.assert_existing_objects_matched()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GoMwf8CFDu9r"
      },
      "source": [
        "Hay muchos objetos en el *checkpoint* que no han sido emparejados, incluyendo el *kernel* de la capa y las variables del optimized. `status.assert_consumed()` solo pasa si el *checkpoint* y el programa empatan exactamente, y arrojara una excepcion en este caso."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "KCcmJ-2j9RUP"
      },
      "source": [
        "### Restauraciones retrasadas\n",
        "\n",
        "Los objetos `Layer` en TensorFlow pueden retrasar la creacion de variables para su primera llamada, cuando las dimensiones de entrada estan disponibles. Por ejemplo, las dimensiones de un *layer kernel* `Dense` dependen tanto de las entradas de la capa como de las dimensiones de salida, y por ende solo la dimension de salida que es requerida como argumento de construccion no es suficiente informacion para la creacion de las variables. Como la llamada a `Layer` tambien lee el valor de la variable, una restauracion debe pasa entre las variables de creacion y su primer uso.\n",
        "\n",
        "Para dar soporte a este idioma, `tf.train.Checkpoint` forma una lista de restauranciones que no tienen una vatiable que empate aun.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "TXYUCO3v-I72",
        "colab": {}
      },
      "source": [
        "delayed_restore = tf.Variable(tf.zeros([1, 5]))\n",
        "print(delayed_restore.numpy())  # No restaurado; siguen siendo ceros\n",
        "fake_layer.kernel = delayed_restore\n",
        "print(delayed_restore.numpy())  # Restaurado"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-DWhJ3glyobN"
      },
      "source": [
        "### Revision manual de los checkpoints\n",
        "\n",
        "`tf.train.list_variables` lista las *checkpoint keys* y las dimensiones de las variables en un checkpoint. Las *Checkpoint keys* son los *paths* del grafo mostrado anteriormente."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "RlRsADTezoBD",
        "colab": {}
      },
      "source": [
        "tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5fxk_BnZ4W1b"
      },
      "source": [
        "### Rastreo de listas y diccionarios\n",
        "List and dictionary tracking\n",
        "\n",
        "Igual que en las asignaciones directas de atributos, e.g.  `self.l1 = tf.keras.layers.Dense(5)`, la asignacion de listas y diccionarios a atributos rastreara sus contenidos."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "rfaIbDtDHAr_",
        "colab": {}
      },
      "source": [
        "save = tf.train.Checkpoint()\n",
        "save.listed = [tf.Variable(1.)]\n",
        "save.listed.append(tf.Variable(2.))\n",
        "save.mapped = {'one': save.listed[0]}\n",
        "save.mapped['two'] = save.listed[1]\n",
        "save_path = save.save('./tf_list_example')\n",
        "\n",
        "restore = tf.train.Checkpoint()\n",
        "v2 = tf.Variable(0.)\n",
        "assert 0. == v2.numpy()  # No ha sido restaurado aun\n",
        "restore.mapped = {'two': v2}\n",
        "restore.restore(save_path)\n",
        "assert 2. == v2.numpy()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "UTKvbxHcI3T2"
      },
      "source": [
        "Puede notar que existen *wrappers* de objetos para listas y diccionarios. Estos *wrappers* pueden ser incluidos en versiones *checkpoint* de las estructuras de datos subyacientes. Asi como la carga basada en atributos, estos *wrappers* restauran el valor de una variable al momento de ser agregada al contenedor.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "s0Uq1Hv5JCmm",
        "colab": {}
      },
      "source": [
        "restore.listed = []\n",
        "print(restore.listed)  # ListWrapper([])\n",
        "v1 = tf.Variable(0.)\n",
        "restore.listed.append(v1)  # Restaurar v1, del restore() de la celda anterior\n",
        "assert 1. == v1.numpy()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OxCIf2J6JyQ8"
      },
      "source": [
        "El mismo rastreo es aplicado automaticamente a subclases de `tf.keras.Model`, y puede ser usado para rastrear listas de capas por ejemplo.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zGG1tOM0L6iM"
      },
      "source": [
        "## Guardar checkponts basados en objetos con Estimator\n",
        "\n",
        "Ver la [guia a Estimator](https://www.tensorflow.org/guide/estimator). NOTA: documentacion en ingles.\n",
        "\n",
        "Los Estimators guardan checkpoints por default con nombres de variables en lugar de el ogjeto grafo descrito en las secciones anteriores. `tf.train.Checkpoint` aceptara ckeckpoints basadon en nombres, pero los nombres de las variables podrian cambiar al movr las pertes del modelo fuera del `model_fn` del Estimator. Guardar checkpoints basados en objetos facilita el entrenamiento de un modelo dentro de un Estimator y su posterior uso fuera de el.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "-8AMJeueNyoM",
        "colab": {}
      },
      "source": [
        "import tensorflow.compat.v1 as tf_compat"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "T6fQsBzJQN2y",
        "colab": {}
      },
      "source": [
        "def model_fn(features, labels, mode):\n",
        "  net = Net()\n",
        "  opt = tf.keras.optimizers.Adam(0.1)\n",
        "  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),\n",
        "                             optimizer=opt, net=net)\n",
        "  with tf.GradientTape() as tape:\n",
        "    output = net(features['x'])\n",
        "    loss = tf.reduce_mean(tf.abs(output - features['y']))\n",
        "  variables = net.trainable_variables\n",
        "  gradients = tape.gradient(loss, variables)\n",
        "  return tf.estimator.EstimatorSpec(\n",
        "    mode,\n",
        "    loss=loss,\n",
        "    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),\n",
        "                      ckpt.step.assign_add(1)),\n",
        "    # Decirle al Estimator gue guarde \"ckpt\" en un formato basado en objeto.\n",
        "    scaffold=tf_compat.train.Scaffold(saver=ckpt))\n",
        "\n",
        "tf.keras.backend.clear_session()\n",
        "est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')\n",
        "est.train(toy_dataset, steps=10)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tObYHnrrb_mL"
      },
      "source": [
        "`tf.train.Checkpoint` puede cargar los checkpoints del Estimator de su `model_dir`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Q6IP3Y_wb-fs",
        "colab": {}
      },
      "source": [
        "opt = tf.keras.optimizers.Adam(0.1)\n",
        "net = Net()\n",
        "ckpt = tf.train.Checkpoint(\n",
        "  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)\n",
        "ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))\n",
        "ckpt.step.numpy()  # De est.train(..., steps=10)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "knyUFMrJg8y4"
      },
      "source": [
        "## Resumen\n",
        "\n",
        "Los objetos de TensorFlow proveen un mecanismo facil y automatico para guardar y restaurar los valores de las variables que usan.\n"
      ]
    }
  ]
}